#include "statsxx/machine_learning/NeuralNet.hpp"

// STL
#include <cmath>                  // std::sqrt()
#include <cstdlib>                // std::exit(), EXIT_FAILURE
#include <iostream>
#include <stdexcept>              // std::runtime_error
#include <string>                 // std::to_string()
#include <vector>                 // std::vector<>

// jScience
#include "jScience/linalg.hpp"    // Matrix<>, Vector<>
#include "jrandnum.hpp"           // rand_num_normal_Mersenne_twister(), rand_num_uniform_Mersenne_twister()

// stats++
#include "statsxx/statistics.hpp" // statistics::sum()


//
// DESC: Creates a multilayer perceptron neural network, possibly with full connectivity and recurrent links.
//
// PARAMS:
//     ni            - number of input neurons (must be >= 1)
//     no            - number of output neurons (must be >= 1)
//     nl            - number of hidden layers
//     nhn           - neurons per hidden layer; must hold at least nl entries
//     fully_connect - if true, every new neuron receives links from ALL earlier neurons
//                     (inputs and bias included), not just from the previous layer
//     recurrent     - if true, each hidden neuron also gets a time-delayed self-link
//     af_type       - hidden-layer activation function: 0 == logistic, 1 == tanh, 2 == softplus
//     isClass       - true  => classification outputs (softmax for n > 1, logistic for n == 1)
//                     false => regression outputs (identity)
//     w             - optional initial weights, one per link in creation order; if empty,
//                     weights are drawn randomly via optimize_link_weights()
//
inline void NEURAL_NET::create_MLP(
                                   int ni,
                                   int no,
                                   int nl,
                                   std::vector<int> nhn,
                                   bool fully_connect,
                                   bool recurrent,
                                   int  af_type, // 0 == logistic, 1 == tanh, 2 == softplus
                                   bool isClass,
                                   const std::vector<double> &w
                                   )
{
    // --- validate arguments -------------------------------------------------
    // note: diagnostics go to stderr and fatal errors exit with a failure code;
    //       the old code printed to stdout, exited with 0 (success), and named
    //       the wrong function (init()) in the messages
    if( ni < 1 )
    {
        std::cerr << "Error in NEURAL_NET::create_MLP(): Attempting to initialize a NN with less than 1 input node." << std::endl;
        exit(EXIT_FAILURE);
    }
    else if( no < 1 )
    {
        std::cerr << "Error in NEURAL_NET::create_MLP(): Attempting to initialize a NN with less than 1 output node." << std::endl;
        exit(EXIT_FAILURE);
    }
    else if( (nl > 0) && (nhn.size() < static_cast<std::vector<int>::size_type>(nl)) )
    {
        // without this guard, nhn[i] below reads out of bounds
        std::cerr << "Error in NEURAL_NET::create_MLP(): nhn supplies only " << nhn.size() << " layer sizes but nl == " << nl << "." << std::endl;
        exit(EXIT_FAILURE);
    }
    else if( isClass && (no == 2) )
    {
        std::cerr << "Warning in NEURAL_NET::create_MLP(): Attempting to initialize a NN for 1-of-n classification with 2 output nodes (binary classification). It is suggested, for such problems, to use 1 output node." << std::endl;
    }

    this->init();

    m_ninp = ni;
    m_nout = no;

    // total number of neurons (inputs + bias + hidden + outputs), used to size
    // the per-neuron link-index tables
    // note: the `+1' is for bias
    int nn = ni + 1;

    for(int i = 0; i < nl; ++i)
    {
        nn += nhn[i];
    }

    nn += no;

    m_links_in.resize(  nn );
    m_links_out.resize( nn );

    m_isClassif = isClass;

    // counter0/counter1 bracket the previous layer: when not fully connected, a
    // new neuron links back only to neuron indices [counter0, counter1)
    decltype(m_neurons.size()) counter0 = 0;
    decltype(m_neurons.size()) counter1 = 0;

    // input neurons occupy indices [0, ni)
    for(int i = 0; i < ni; ++i)
    {
        m_neurons.push_back( NEURON(i, NEURON::TYPE::INPUT, NEURON::ACTIVATION_FUNC::IDENTITY) );
        m_inp_neurons.push_back(i);
    }

    // single shared bias neuron at index ni
    m_neurons.push_back( NEURON(ni, NEURON::TYPE::BIAS, NEURON::ACTIVATION_FUNC::IDENTITY) );
    m_bias_neurons.push_back(ni);

    // --- hidden layers ------------------------------------------------------
    for(int i = 0; i < nl; ++i)
    {
        counter1 = m_neurons.size();

        for(int j = 0; j < nhn[i]; ++j)
        {
            switch(af_type)
            {
            case 0:
                m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::HIDDEN, NEURON::ACTIVATION_FUNC::LOGISTIC) );
                break;
            case 1:
                m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::HIDDEN, NEURON::ACTIVATION_FUNC::TANH) );
                break;
            case 2:
                m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::HIDDEN, NEURON::ACTIVATION_FUNC::SOFTPLUS) );
                break;
            default:
                throw std::runtime_error("activation function type " + std::to_string(af_type) + " not recognized");
            }

            // fully connected: link back to every earlier neuron; otherwise only
            // to the previous layer [counter0, counter1)
            // note: keep the container's size type (old `int startk' narrowed)
            decltype(m_neurons.size()) startk = fully_connect ? 0 : counter0;

            for( decltype(m_neurons.size()) k = startk; k < counter1; ++k )
            {
                m_links.push_back( LINK(k, (m_neurons.size()-1), 0.0, 0) );
                m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
                m_links_out[k].push_back(m_links.size()-1);
            }

            // note: in the case of non-fully connected networks, the bias neuron is not hit after the first layer
            if( !fully_connect && (i > 0) )
            {
                m_links.push_back( LINK(m_bias_neurons[0], (m_neurons.size()-1), 0.0, 0) );
                m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
                m_links_out[m_bias_neurons[0]].push_back(m_links.size()-1);
            }

            // time-delayed self-link (link type 1)
            if( recurrent )
            {
                m_links.push_back( LINK((m_neurons.size()-1), (m_neurons.size()-1), 0.0, 1) );
                m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
                m_links_out[m_neurons.size()-1].push_back(m_links.size()-1);
            }
        }

        counter0 = counter1;
    }

    counter1 = m_neurons.size();

    // --- output layer -------------------------------------------------------
    for(int i = 0; i < no; ++i)
    {
        // 1-OF-n (n > 1) CLASSIFICATION ...
        if( isClass && (no > 1) )
        {
            m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::OUTPUT, NEURON::ACTIVATION_FUNC::SOFTMAX) );
        }
        // ... BINARY CLASSIFICATION ...
        else if(isClass)
        {
            m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::OUTPUT, NEURON::ACTIVATION_FUNC::LOGISTIC) );
        }
        // ... REGRESSION
        else
        {
            // note: linear output nodes pair with a matching (squared-error) loss function
            m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::OUTPUT, NEURON::ACTIVATION_FUNC::IDENTITY) );
        }

        m_out_neurons.push_back( m_neurons.size()-1 );

        decltype(m_neurons.size()) startk = fully_connect ? 0 : counter0;

        for( decltype(m_neurons.size()) k = startk; k < counter1; ++k )
        {
            m_links.push_back( LINK(k, (m_neurons.size()-1), 0.0, 0) );
            m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
            m_links_out[k].push_back(m_links.size()-1);
        }

        // note: in the case of non-fully connected networks, the bias neuron is not hit if any hidden layers exist
        if( !fully_connect && (nl > 0) )
        {
            m_links.push_back( LINK(m_bias_neurons[0], (m_neurons.size()-1), 0.0, 0) );
            m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
            m_links_out[m_bias_neurons[0]].push_back(m_links.size()-1);
        }
    }

    // debug diagnostic: compare constructed link count against supplied weights
    if(!w.empty())
    {
        std::cout << "m_links.size() : " << m_links.size() << '\n';
        std::cout << "w.size()       : " << w.size() << '\n';
    }

    if(w.empty())
    {
        // no weights supplied: draw random, fan-in-scaled initial weights
        this->optimize_link_weights();
    }
    else
    {
        // weights supplied (e.g. a previously trained net): assign in link order
        this->assign_weights(w);
    }

    this->create_lookup_tables();

    this->get_layer_neuron_idx();
}


//
// DESC: Creates a multilayer perceptron neural network from an explicit weight
//       matrix and bias vector (plain layer-to-layer connectivity only).
//
// PARAMS:
//     architecture - neurons per layer: {n_inputs, n_hidden_1, ..., n_outputs};
//                    must contain at least 2 entries (input and output layers)
//     recurrent    - unsupported here (self-link weights cannot come from W): a
//                    warning is printed and zero-weight self-links are created
//     af_type      - hidden-layer activation function: 0 == logistic, 1 == tanh, 2 == softplus
//     isClass      - true  => classification outputs (softmax for n > 1, logistic for n == 1)
//                     false => regression outputs (identity)
//     W            - link weights, stored in lower-triangular form, indexed by
//                    "global" neuron indices that SKIP the bias neuron
//     bias         - bias weight for each non-input neuron, indexed by the same
//                    bias-free global scheme
//
inline void NEURAL_NET::create_MLP(
                                   std::vector<int> architecture,
                                   bool recurrent,
                                   int  af_type, // 0 == logistic, 1 == tanh, 2 == softplus
                                   bool isClass,
                                   const Matrix<double> &W,
                                   const Vector<double> &bias
                                   )
{
    // an MLP needs at least an input and an output layer; front()/back() below
    // would otherwise be undefined behavior
    if( architecture.size() < 2 )
    {
        std::cerr << "Error in NEURAL_NET::create_MLP(): architecture must specify at least an input and an output layer." << std::endl;
        exit(EXIT_FAILURE);
    }

    this->init();

    this->m_ninp = architecture.front();
    this->m_nout = architecture.back();

    // note: when computing the number of neurons here, the +1 is for a bias
    int nn = statistics::sum(architecture) + 1;

    this->m_links_in.resize(  nn );
    this->m_links_out.resize( nn );

    this->m_isClassif = isClass;

    // counter0/counter1 bracket the previous layer: new neurons link back to
    // neuron indices [counter0, counter1)
    decltype(m_neurons.size()) counter0 = 0;
    decltype(m_neurons.size()) counter1 = 0;

    // input neurons occupy indices [0, m_ninp)
    for(int i = 0; i < this->m_ninp; ++i)
    {
        m_neurons.push_back( NEURON(i, NEURON::TYPE::INPUT, NEURON::ACTIVATION_FUNC::IDENTITY) );
        m_inp_neurons.push_back(i);
    }

    // single shared bias neuron, located right after the inputs
    m_neurons.push_back( NEURON(this->m_ninp, NEURON::TYPE::BIAS, NEURON::ACTIVATION_FUNC::IDENTITY) );
    m_bias_neurons.push_back(this->m_ninp);

    // for each of the hidden layers ...
    // note: size_type index avoids the signed/unsigned comparison the old
    //       `auto i = 1' (deduced int) produced against architecture.size()
    for(std::vector<int>::size_type i = 1; i + 1 < architecture.size(); ++i)
    {
        counter1 = m_neurons.size();

        // for each of the neurons in this hidden layer ...
        for(int j = 0; j < architecture[i]; ++j)
        {
            switch(af_type)
            {
            case 0:
                m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::HIDDEN, NEURON::ACTIVATION_FUNC::LOGISTIC) );
                break;
            case 1:
                m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::HIDDEN, NEURON::ACTIVATION_FUNC::TANH) );
                break;
            case 2:
                m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::HIDDEN, NEURON::ACTIVATION_FUNC::SOFTPLUS) );
                break;
            default:
                throw std::runtime_error("activation function type " + std::to_string(af_type) + " not recognized");
            }

            // link the new neuron to every neuron of the previous layer
            for( decltype(m_neurons.size()) k = counter0; k < counter1; ++k )
            {
                m_links.push_back( LINK(k, (m_neurons.size()-1), 0.0, 0) );
                m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
                m_links_out[k].push_back(m_links.size()-1);

                // "global" (bias-free) index of the new neuron in the W/bias
                // numbering; the internal index is always one higher because the
                // bias neuron is skipped in that numbering
                int idx2 = static_cast<int>(m_neurons.size() - 1) - 1;

                // a link originating at the bias neuron takes its weight from `bias' ...
                if( k == static_cast<decltype(m_neurons.size())>(this->m_ninp) )
                {
                    m_links.back().weight = bias(idx2);
                    continue;
                }

                // ... every other link takes its weight from W, after mapping the
                // source neuron's internal index to the bias-free numbering
                int idx1 = static_cast<int>(k);
                if(idx1 > this->m_ninp)
                {
                    --idx1;
                }

                // note: recall that W is stored in lower-triangular form
                this->m_links.back().weight = W(idx2,idx1);
            }

            // layers past the first hidden layer never see the bias neuron in the
            // previous-layer range above, so attach its link explicitly
            if( i > 1 )
            {
                m_links.push_back( LINK(m_bias_neurons[0], (m_neurons.size()-1), 0.0, 0) );
                m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
                m_links_out[m_bias_neurons[0]].push_back(m_links.size()-1);

                int idx2 = static_cast<int>(m_neurons.size() - 1) - 1;
                this->m_links.back().weight = bias(idx2);
            }

            if( recurrent )
            {
                // self-link weights cannot be represented in W; warn on stderr
                // (was stdout) and fall back to a zero-weight time-delayed self-link
                std::cerr << "error in create_MLP(): cannot initialize a recurrent NN with a weight matrix" << '\n';

                m_links.push_back( LINK((m_neurons.size()-1), (m_neurons.size()-1), 0.0, 1) );
                m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
                m_links_out[m_neurons.size()-1].push_back(m_links.size()-1);
            }
        }

        counter0 = counter1;
    }

    counter1 = m_neurons.size();

    // --- output layer -------------------------------------------------------
    for(int i = 0; i < this->m_nout; ++i)
    {
        // 1-OF-n (n > 1) CLASSIFICATION ...
        if( isClass && (this->m_nout > 1) )
        {
            m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::OUTPUT, NEURON::ACTIVATION_FUNC::SOFTMAX) );
        }
        // ... BINARY CLASSIFICATION ...
        else if(isClass)
        {
            m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::OUTPUT, NEURON::ACTIVATION_FUNC::LOGISTIC) );
        }
        // ... REGRESSION
        else
        {
            // note: linear output nodes pair with a matching (squared-error) loss function
            m_neurons.push_back( NEURON(m_neurons.size(), NEURON::TYPE::OUTPUT, NEURON::ACTIVATION_FUNC::IDENTITY) );
        }

        m_out_neurons.push_back( m_neurons.size()-1 );

        // link the output neuron to every neuron of the last hidden layer
        // (or to inputs + bias when there are no hidden layers)
        for( decltype(m_neurons.size()) k = counter0; k < counter1; ++k )
        {
            m_links.push_back( LINK(k, (m_neurons.size()-1), 0.0, 0) );
            m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
            m_links_out[k].push_back(m_links.size()-1);

            // bias-free "global" index of the new neuron (see hidden-layer loop)
            int idx2 = static_cast<int>(m_neurons.size() - 1) - 1;

            if( k == static_cast<decltype(m_neurons.size())>(this->m_ninp) )
            {
                m_links.back().weight = bias(idx2);
                continue;
            }

            int idx1 = static_cast<int>(k);
            if(idx1 > this->m_ninp)
            {
                --idx1;
            }

            // note: recall that W is stored in lower-triangular form
            this->m_links.back().weight = W(idx2,idx1);
        }

        // with hidden layers present, the outputs never see the bias neuron in
        // the range above, so attach its link explicitly
        if( architecture.size() > 2 )
        {
            m_links.push_back( LINK(m_bias_neurons[0], (m_neurons.size()-1), 0.0, 0) );
            m_links_in[m_neurons.size()-1].push_back(m_links.size()-1);
            m_links_out[m_bias_neurons[0]].push_back(m_links.size()-1);

            int idx2 = static_cast<int>(m_neurons.size() - 1) - 1;
            this->m_links.back().weight = bias(idx2);
        }
    }

    this->create_lookup_tables();

    this->get_layer_neuron_idx();
}


//========================================================================
//========================================================================
//
// NAME: void NEURAL_NET::optimize_link_weights()
//
// DESC: Calculates the optimal initial weights in a network.
//
// NOTES:
//     ! This algorithm assumes the following:
//          - The input data has been normalized such that each input has a covariance of 1
//          - LeCun tanh activation functions are used (they have an effective gain of 1)
//          - It is reasonable to set the covariance of the total input and output of each node to 1
//     ! See:
//          http://wiki.analyticsxx.com/view/Initial_weight_selection
//
//========================================================================
//========================================================================
inline void NEURAL_NET::optimize_link_weights()
{
    for( decltype(m_neurons.size()) k = 0; k < m_neurons.size(); ++k )
    {
        //std::cout << "m_links_in[k].size(): " << m_links_in[k].size() << '\n';

        double alpha = std::sqrt( 3.0/static_cast<double>(m_links_in[k].size()) ); // uniform distribution
//        double alpha = std::sqrt( 1.0/static_cast<double>(m_links_in[k].size()) ); // normal distribution

        for( auto l : m_links_in[k] )
        {
            m_links[l].weight = rand_num_uniform_Mersenne_twister(-alpha, alpha);
//            m_links[l].weight = rand_num_normal_Mersenne_twister(0.0, alpha);
        }
    }
}


